{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run this cell in Colab to install the package if it is not already\n",
    "# installed. %pip (unlike !pip) installs into the running kernel's environment.\n",
    "%pip install git+https://github.com/openai/glide-text2im"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Tuple\n",
    "\n",
    "from IPython.display import display\n",
    "from PIL import Image\n",
    "import numpy as np\n",
    "import torch as th\n",
    "import torch.nn.functional as F\n",
    "\n",
    "from glide_text2im.download import load_checkpoint\n",
    "from glide_text2im.model_creation import (\n",
    "    create_model_and_diffusion,\n",
    "    model_and_diffusion_defaults,\n",
    "    model_and_diffusion_defaults_upsampler\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pick the compute device. Everything below runs on either CPU or GPU, but\n",
    "# sampling is far slower on CPU: on the order of 20 minutes per sample,\n",
    "# versus under a minute on a GPU.\n",
    "has_cuda = th.cuda.is_available()\n",
    "\n",
    "device = th.device('cuda' if has_cuda else 'cpu')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the base inpainting model, which produces the 64x64 samples.\n",
    "options = model_and_diffusion_defaults()\n",
    "options.update(\n",
    "    inpaint=True,\n",
    "    use_fp16=has_cuda,  # half precision only when a GPU is available\n",
    "    timestep_respacing='100',  # 100 respaced diffusion steps for fast sampling\n",
    ")\n",
    "model, diffusion = create_model_and_diffusion(**options)\n",
    "model.eval()\n",
    "if has_cuda:\n",
    "    model.convert_to_fp16()\n",
    "model.to(device)\n",
    "model.load_state_dict(load_checkpoint('base-inpaint', device))\n",
    "print('total base parameters', sum(p.numel() for p in model.parameters()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the upsampler inpainting model, which takes the 64x64 samples up to 256x256.\n",
    "options_up = model_and_diffusion_defaults_upsampler()\n",
    "options_up.update(\n",
    "    inpaint=True,\n",
    "    use_fp16=has_cuda,  # half precision only when a GPU is available\n",
    "    timestep_respacing='fast27',  # 27 respaced diffusion steps for very fast sampling\n",
    ")\n",
    "model_up, diffusion_up = create_model_and_diffusion(**options_up)\n",
    "model_up.eval()\n",
    "if has_cuda:\n",
    "    model_up.convert_to_fp16()\n",
    "model_up.to(device)\n",
    "model_up.load_state_dict(load_checkpoint('upsample-inpaint', device))\n",
    "print('total upsampler parameters', sum(p.numel() for p in model_up.parameters()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def show_images(batch: th.Tensor):\n",
    "    \"\"\"Display a (B, 3, H, W) batch with values in [-1, 1] inline as one horizontal strip.\"\"\"\n",
    "    pixels = ((batch + 1) * 127.5).round().clamp(0, 255).to(th.uint8).cpu()\n",
    "    height = batch.shape[2]\n",
    "    # (B, C, H, W) -> (H, B, W, C) -> (H, B*W, C): place the images side by side.\n",
    "    strip = pixels.permute(2, 0, 3, 1).reshape([height, -1, 3])\n",
    "    display(Image.fromarray(strip.numpy()))\n",
    "\n",
    "def read_image(path: str, size: int = 256) -> th.Tensor:\n",
    "    \"\"\"Load an image as a (1, 3, size, size) float tensor scaled to [-1, 1].\n",
    "\n",
    "    The image is converted to RGB and resized to size x size with bicubic\n",
    "    resampling (the aspect ratio is not preserved).\n",
    "    \"\"\"\n",
    "    # Open inside a context manager so the file handle is always released;\n",
    "    # convert() returns an independent, fully loaded copy.\n",
    "    with Image.open(path) as raw_img:\n",
    "        pil_img = raw_img.convert('RGB')\n",
    "    pil_img = pil_img.resize((size, size), resample=Image.BICUBIC)\n",
    "    img = np.array(pil_img)\n",
    "    # (H, W, C) uint8 -> (1, C, H, W) float in [-1, 1].\n",
    "    return th.from_numpy(img)[None].permute(0, 3, 1, 2).float() / 127.5 - 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sampling parameters\n",
    "prompt = \"a corgi in a field\"\n",
    "batch_size = 1\n",
    "guidance_scale = 5.0\n",
    "\n",
    "# Seed torch's RNGs so Restart & Run All reproduces the same samples.\n",
    "seed = 0\n",
    "th.manual_seed(seed)\n",
    "\n",
    "# Tune this parameter to control the sharpness of 256x256 images.\n",
    "# A value of 1.0 is sharper, but sometimes results in grainy artifacts.\n",
    "upsample_temp = 0.997\n",
    "\n",
    "# Source image we are inpainting, loaded at both model resolutions.\n",
    "source_image_path = 'grass.png'\n",
    "source_image_256 = read_image(source_image_path, size=256)\n",
    "source_image_64 = read_image(source_image_path, size=64)\n",
    "\n",
    "# The mask is a single-channel 64x64 float tensor of 0s and 1s: 1 marks\n",
    "# pixels kept from the source image, 0 marks the region to inpaint (here,\n",
    "# every row from 20 down). It is upsampled for the second stage.\n",
    "source_mask_64 = th.ones_like(source_image_64)[:, :1]\n",
    "source_mask_64[:, :, 20:] = 0\n",
    "source_mask_256 = F.interpolate(source_mask_64, (256, 256), mode='nearest')\n",
    "\n",
    "# Visualize the image we are inpainting\n",
    "show_images(source_image_256 * source_mask_256)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "##############################\n",
    "# Sample from the base model #\n",
    "##############################\n",
    "\n",
    "# Tokenize the prompt and pad it to options['text_ctx'], which also\n",
    "# returns the matching token mask.\n",
    "tokens, mask = model.tokenizer.padded_tokens_and_mask(\n",
    "    model.tokenizer.encode(prompt), options['text_ctx']\n",
    ")\n",
    "\n",
    "# Classifier-free guidance doubles the batch: each conditional row gets an\n",
    "# unconditional partner built from an empty token sequence.\n",
    "full_batch_size = batch_size * 2\n",
    "uncond_tokens, uncond_mask = model.tokenizer.padded_tokens_and_mask(\n",
    "    [], options['text_ctx']\n",
    ")\n",
    "\n",
    "# Keyword arguments for every model call. Conditional rows come first, then\n",
    "# the unconditional ones; the masked source image and mask are tiled to match.\n",
    "model_kwargs = dict(\n",
    "    tokens=th.tensor(\n",
    "        [tokens] * batch_size + [uncond_tokens] * batch_size,\n",
    "        device=device,\n",
    "    ),\n",
    "    mask=th.tensor(\n",
    "        [mask] * batch_size + [uncond_mask] * batch_size,\n",
    "        dtype=th.bool,\n",
    "        device=device,\n",
    "    ),\n",
    "    inpaint_image=(\n",
    "        (source_image_64 * source_mask_64)\n",
    "        .repeat(full_batch_size, 1, 1, 1)\n",
    "        .to(device)\n",
    "    ),\n",
    "    inpaint_mask=source_mask_64.repeat(full_batch_size, 1, 1, 1).to(device),\n",
    ")\n",
    "\n",
    "# Classifier-free guidance: run the model on two copies of the first half of\n",
    "# the batch (paired with the conditional and the empty tokens respectively),\n",
    "# then push the conditional eps prediction away from the unconditional one.\n",
    "def model_fn(x_t, ts, **kwargs):\n",
    "    cond_half = x_t[: len(x_t) // 2]\n",
    "    out = model(th.cat([cond_half, cond_half], dim=0), ts, **kwargs)\n",
    "    eps, rest = out[:, :3], out[:, 3:]\n",
    "    cond_eps, uncond_eps = th.split(eps, len(eps) // 2, dim=0)\n",
    "    guided_eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps)\n",
    "    # Both halves receive the same guided eps; the extra channels pass through.\n",
    "    return th.cat([th.cat([guided_eps, guided_eps], dim=0), rest], dim=1)\n",
    "\n",
    "def denoised_fn(x_start):\n",
    "    # Overwrite the x_start prediction with the known source pixels wherever\n",
    "    # inpaint_mask is 1, so only the masked-out region is generated.\n",
    "    keep = model_kwargs['inpaint_mask']\n",
    "    known = model_kwargs['inpaint_image']\n",
    "    return x_start * (1 - keep) + known * keep\n",
    "\n",
    "# Run reverse diffusion on the doubled batch, then keep only the first\n",
    "# batch_size rows (model_fn makes both halves carry the guided prediction).\n",
    "base_shape = (full_batch_size, 3, options[\"image_size\"], options[\"image_size\"])\n",
    "model.del_cache()\n",
    "samples = diffusion.p_sample_loop(\n",
    "    model_fn,\n",
    "    base_shape,\n",
    "    device=device,\n",
    "    clip_denoised=True,\n",
    "    progress=True,\n",
    "    model_kwargs=model_kwargs,\n",
    "    cond_fn=None,\n",
    "    denoised_fn=denoised_fn,\n",
    ")[:batch_size]\n",
    "model.del_cache()\n",
    "\n",
    "# Show the output\n",
    "show_images(samples)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "##############################\n",
    "# Upsample the 64x64 samples #\n",
    "##############################\n",
    "\n",
    "# The upsampler is called directly (no classifier-free guidance), so it\n",
    "# only needs the conditional prompt tokens.\n",
    "tokens, mask = model_up.tokenizer.padded_tokens_and_mask(\n",
    "    model_up.tokenizer.encode(prompt), options_up['text_ctx']\n",
    ")\n",
    "\n",
    "# Upsampler conditioning: the 64x64 samples rounded to the 256 levels of an\n",
    "# 8-bit image, the prompt tokens, and the masked 256x256 source with its mask.\n",
    "model_kwargs = dict(\n",
    "    low_res=((samples + 1) * 127.5).round() / 127.5 - 1,\n",
    "    tokens=th.tensor([tokens] * batch_size, device=device),\n",
    "    mask=th.tensor([mask] * batch_size, dtype=th.bool, device=device),\n",
    "    inpaint_image=(\n",
    "        (source_image_256 * source_mask_256)\n",
    "        .repeat(batch_size, 1, 1, 1)\n",
    "        .to(device)\n",
    "    ),\n",
    "    inpaint_mask=source_mask_256.repeat(batch_size, 1, 1, 1).to(device),\n",
    ")\n",
    "\n",
    "def denoised_fn(x_start):\n",
    "    # Pin the known 256x256 source pixels (inpaint_mask == 1) in the x_start\n",
    "    # prediction; only the masked-out region is left to the model.\n",
    "    inpaint_mask = model_kwargs['inpaint_mask']\n",
    "    return x_start * (1 - inpaint_mask) + model_kwargs['inpaint_image'] * inpaint_mask\n",
    "\n",
    "# Sample from the upsampler model. The starting noise is scaled by\n",
    "# upsample_temp (1.0 is sharper but can look grainy).\n",
    "model_up.del_cache()\n",
    "up_shape = (batch_size, 3, options_up[\"image_size\"], options_up[\"image_size\"])\n",
    "up_samples = diffusion_up.p_sample_loop(\n",
    "    model_up,\n",
    "    up_shape,\n",
    "    noise=th.randn(up_shape, device=device) * upsample_temp,\n",
    "    device=device,\n",
    "    clip_denoised=True,\n",
    "    progress=True,\n",
    "    model_kwargs=model_kwargs,\n",
    "    cond_fn=None,\n",
    "    denoised_fn=denoised_fn,\n",
    ")\n",
    "model_up.del_cache()\n",
    "\n",
    "# Show the output\n",
    "show_images(up_samples)"
   ]
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "e7d6e62d90e7e85f9a0faa7f0b1d576302d7ae6108e9fe361594f8e1c8b05781"
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  },
  "accelerator": "GPU"
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
